/*
* Copyright 2017 StreamSets Inc.
*
* Licensed under the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.streamsets.pipeline.stage.origin.tcp;
import com.google.common.primitives.Bytes;
import com.streamsets.pipeline.api.OnRecordError;
import com.streamsets.pipeline.api.Record;
import com.streamsets.pipeline.api.Stage;
import com.streamsets.pipeline.api.StageException;
import com.streamsets.pipeline.config.DataFormat;
import com.streamsets.pipeline.lib.parser.net.NetTestUtils;
import com.streamsets.pipeline.lib.parser.net.syslog.SyslogFramingMode;
import com.streamsets.pipeline.lib.parser.net.syslog.SyslogMessage;
import com.streamsets.pipeline.lib.tls.TlsConfigErrors;
import com.streamsets.pipeline.sdk.PushSourceRunner;
import com.streamsets.pipeline.stage.util.tls.TLSTestUtils;
import com.streamsets.testing.NetworkUtils;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import org.apache.commons.io.Charsets;
import org.junit.Assert;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.KeyPair;
import java.security.cert.Certificate;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
public class TestTCPServerSource {
public static final String TEN_DELIMITED_RECORDS = "one\ntwo\nthree\nfour\nfive\nsix\nseven\neight\nnine\nten\n";
public static final String SYSLOG_RECORD = "<42>Mar 24 17:18:10 10.1.2.34 Got an error";
@Test
public void syslogRecords() {
Charset charset = Charsets.ISO_8859_1;
final TCPServerSourceConfig configBean = createConfigBean(charset);
TCPServerSource source = new TCPServerSource(configBean);
List<Stage.ConfigIssue> issues = new LinkedList<>();
EmbeddedChannel ch = new EmbeddedChannel(source.buildByteBufToMessageDecoderChain(issues).toArray(new ChannelHandler[0]));
ch.writeInbound(Unpooled.copiedBuffer(SYSLOG_RECORD + configBean.nonTransparentFramingSeparatorCharStr, charset));
assertSyslogRecord(ch);
assertFalse(ch.finishAndReleaseAll());
configBean.syslogFramingMode = SyslogFramingMode.OCTET_COUNTING;
EmbeddedChannel ch2 = new EmbeddedChannel(source.buildByteBufToMessageDecoderChain(issues).toArray(new ChannelHandler[0]));
ch2.writeInbound(Unpooled.copiedBuffer(SYSLOG_RECORD.length() + " " + SYSLOG_RECORD, charset));
assertSyslogRecord(ch2);
assertFalse(ch2.finishAndReleaseAll());
}
private void assertSyslogRecord(EmbeddedChannel ch) {
Object in1 = ch.readInbound();
assertThat(in1, notNullValue());
assertThat(in1, instanceOf(SyslogMessage.class));
SyslogMessage msg1 = (SyslogMessage) in1;
assertThat(msg1.getHost(), equalTo("10.1.2.34"));
assertThat(msg1.getRemainingMessage(), equalTo("Got an error"));
assertThat(msg1.getPriority(), equalTo(42));
assertThat(msg1.getFacility(), equalTo(5));
assertThat(msg1.getSeverity(), equalTo(2));
}
@Test
public void initMethod() throws Exception {
final TCPServerSourceConfig configBean = createConfigBean(Charsets.ISO_8859_1);
List<Stage.ConfigIssue> issues1 = initSourceAndGetIssues(configBean);
assertThat(issues1, hasSize(0));
// empty ports
configBean.ports = new LinkedList<>();
List<Stage.ConfigIssue> issues2 = initSourceAndGetIssues(configBean);
assertThat(issues2, hasSize(1));
assertThat(issues2.get(0).toString(), containsString(Errors.TCP_02.getCode()));
// invalid ports
// too large
configBean.ports = Arrays.asList("123456789");
List<Stage.ConfigIssue> issues3 = initSourceAndGetIssues(configBean);
assertThat(issues3, hasSize(1));
assertThat(issues3.get(0).toString(), containsString(Errors.TCP_03.getCode()));
// not a number
configBean.ports = Arrays.asList("abcd");
List<Stage.ConfigIssue> issues4 = initSourceAndGetIssues(configBean);
assertThat(issues4, hasSize(1));
assertThat(issues4.get(0).toString(), containsString(Errors.TCP_03.getCode()));
// start TLS config tests
configBean.ports = Arrays.asList("9876");
configBean.tlsEnabled = true;
configBean.tlsConfigBean.hasKeyStore = true;
List<Stage.ConfigIssue> issues5 = initSourceAndGetIssues(configBean);
assertThat(issues5, hasSize(1));
assertThat(issues5.get(0).toString(), containsString(TlsConfigErrors.TLS_02.getCode()));
configBean.tlsConfigBean.hasKeyStore = true;
configBean.tlsConfigBean.keyStoreFilePath = "non-existent-file-path";
List<Stage.ConfigIssue> issues6 = initSourceAndGetIssues(configBean);
assertThat(issues6, hasSize(1));
assertThat(issues6.get(0).toString(), containsString(TlsConfigErrors.TLS_01.getCode()));
File blankTempFile = File.createTempFile("blank", "txt");
blankTempFile.deleteOnExit();
configBean.tlsConfigBean.keyStoreFilePath = blankTempFile.getAbsolutePath();
List<Stage.ConfigIssue> issues7 = initSourceAndGetIssues(configBean);
assertThat(issues7, hasSize(1));
assertThat(issues7.get(0).toString(), containsString(TlsConfigErrors.TLS_21.getCode()));
// now, try with real keystore
String hostname = TLSTestUtils.getHostname();
File testDir = new File("target", UUID.randomUUID().toString()).getAbsoluteFile();
testDir.deleteOnExit();
final File keyStore = new File(testDir, "keystore.jks");
keyStore.deleteOnExit();
Assert.assertTrue(testDir.mkdirs());
final String keyStorePassword = "keystore";
KeyPair keyPair = TLSTestUtils.generateKeyPair();
Certificate cert = TLSTestUtils.generateCertificate("CN=" + hostname, keyPair, 30);
TLSTestUtils.createKeyStore(keyStore.toString(), keyStorePassword, "web", keyPair.getPrivate(), cert);
configBean.tlsConfigBean.keyStoreFilePath = keyStore.getAbsolutePath();
configBean.tlsConfigBean.keyStorePassword = "invalid-password";
List<Stage.ConfigIssue> issues9 = initSourceAndGetIssues(configBean);
assertThat(issues9, hasSize(1));
assertThat(issues9.get(0).toString(), containsString(TlsConfigErrors.TLS_21.getCode()));
// finally, a valid certificate/config
configBean.tlsConfigBean.keyStorePassword = keyStorePassword;
List<Stage.ConfigIssue> issues10 = initSourceAndGetIssues(configBean);
assertThat(issues10, hasSize(0));
}
@Test
public void runTextRecordsWithAck() throws StageException, IOException, ExecutionException, InterruptedException {
final Charset charset = Charsets.ISO_8859_1;
final TCPServerSourceConfig configBean = createConfigBean(charset);
configBean.dataFormat = DataFormat.TEXT;
configBean.tcpMode = TCPMode.DELIMITED_RECORDS;
configBean.recordSeparatorStr = "\n";
configBean.ports = NetworkUtils.getRandomPorts(1);
configBean.recordProcessedAckMessage = "record_ack_${record:id()}";
configBean.batchCompletedAckMessage = "batch_ack_${batchSize}";
final TCPServerSource source = new TCPServerSource(configBean);
final String outputLane = "lane";
final PushSourceRunner runner = new PushSourceRunner.Builder(TCPServerDSource.class, source)
.addOutputLane(outputLane)
.build();
final String[] expectedRecords = TEN_DELIMITED_RECORDS.split(configBean.recordSeparatorStr);
final int batchSize = expectedRecords.length;
final List<Record> records = new LinkedList<>();
runner.runInit();
EventLoopGroup workerGroup = new NioEventLoopGroup();
ChannelFuture channelFuture = startTcpClient(
configBean,
workerGroup,
TEN_DELIMITED_RECORDS.getBytes(charset),
true
);
runner.runProduce(new HashMap<>(), batchSize, output -> {
records.addAll(output.getRecords().get(outputLane));
runner.setStop();
});
runner.waitOnProduce();
// Wait until the connection is closed.
final Channel channel = channelFuture.channel();
TCPServerSourceClientHandler clientHandler = channel.pipeline().get(TCPServerSourceClientHandler.class);
final List<String> responses = new LinkedList<>();
for (int i = 0; i < batchSize + 1; i++) {
// one for each record, plus one for the batch
responses.add(clientHandler.getResponse());
}
channel.close();
workerGroup.shutdownGracefully();
assertThat(records, hasSize(batchSize));
for (int i = 0; i < records.size(); i++) {
// validate the output record value
assertThat(records.get(i).get("/text").getValueAsString(), equalTo(expectedRecords[i]));
// validate the record-level ack
assertThat(responses.get(i), equalTo(String.format("record_ack_%s", records.get(i).getHeader().getSourceId())));
}
// validate the batch-level ack
assertThat(responses.get(10), equalTo(String.format("batch_ack_%d", batchSize)));
}
@Test
public void errorHandling() throws StageException, IOException, ExecutionException, InterruptedException {
final Charset charset = Charsets.ISO_8859_1;
final TCPServerSourceConfig configBean = createConfigBean(charset);
configBean.dataFormat = DataFormat.JSON;
configBean.tcpMode = TCPMode.DELIMITED_RECORDS;
configBean.recordSeparatorStr = "\n";
configBean.ports = NetworkUtils.getRandomPorts(1);
final TCPServerSource source = new TCPServerSource(configBean);
final String outputLane = "lane";
final PushSourceRunner toErrorRunner = new PushSourceRunner.Builder(TCPServerDSource.class, source)
.addOutputLane(outputLane)
.setOnRecordError(OnRecordError.TO_ERROR)
.build();
final List<Record> records = new LinkedList<>();
final List<Record> errorRecords = new LinkedList<>();
runAndCollectRecords(
toErrorRunner,
configBean,
records,
errorRecords,
1,
outputLane,
"{\"invalid_json\": yes}\n".getBytes(charset),
true,
false
);
assertThat(records, empty());
assertThat(errorRecords, hasSize(1));
assertThat(
errorRecords.get(0).getHeader().getErrorCode(),
equalTo(com.streamsets.pipeline.lib.parser.Errors.DATA_PARSER_04.getCode())
);
final PushSourceRunner discardRunner = new PushSourceRunner.Builder(TCPServerDSource.class, source)
.addOutputLane(outputLane)
.setOnRecordError(OnRecordError.DISCARD)
.build();
records.clear();
errorRecords.clear();
configBean.ports = NetworkUtils.getRandomPorts(1);
runAndCollectRecords(
discardRunner,
configBean,
records,
errorRecords,
1,
outputLane,
"{\"invalid_json\": yes}\n".getBytes(charset),
true,
false
);
assertThat(records, empty());
assertThat(errorRecords, empty());
configBean.ports = NetworkUtils.getRandomPorts(1);
final PushSourceRunner stopPipelineRunner = new PushSourceRunner.Builder(TCPServerDSource.class, source)
.addOutputLane(outputLane)
.setOnRecordError(OnRecordError.STOP_PIPELINE)
.build();
records.clear();
errorRecords.clear();
try {
runAndCollectRecords(
stopPipelineRunner,
configBean,
records,
errorRecords,
1,
outputLane,
"{\"invalid_json\": yes}\n".getBytes(charset),
true,
true
);
Assert.fail("ExecutionException should have been thrown");
} catch (ExecutionException e) {
assertThat(e.getCause(), instanceOf(RuntimeException.class));
final RuntimeException runtimeException = (RuntimeException) e.getCause();
assertThat(runtimeException.getCause(), instanceOf(StageException.class));
final StageException stageException = (StageException) runtimeException.getCause();
assertThat(stageException.getErrorCode().getCode(), equalTo(Errors.TCP_06.getCode()));
}
}
private void runAndCollectRecords(
PushSourceRunner runner,
TCPServerSourceConfig configBean,
List<Record> records,
List<Record> errorRecords,
int batchSize,
String outputLane,
byte[] data,
boolean randomlySlice,
boolean runEmptyProduceAtEnd
) throws StageException, InterruptedException, ExecutionException {
runner.runInit();
EventLoopGroup workerGroup = new NioEventLoopGroup();
runner.runProduce(new HashMap<>(), batchSize, output -> {
records.addAll(output.getRecords().get(outputLane));
if (!runEmptyProduceAtEnd) {
runner.setStop();
}
});
ChannelFuture channelFuture = startTcpClient(
configBean,
workerGroup,
data,
randomlySlice
);
// Wait until the connection is closed.
channelFuture.channel().closeFuture().sync();
// wait for the push source runner produce to complete
runner.waitOnProduce();
errorRecords.addAll(runner.getErrorRecords());
if (runEmptyProduceAtEnd) {
runner.runProduce(new HashMap<>(), 0, output -> {
runner.setStop();
});
runner.waitOnProduce();
}
runner.runDestroy();
workerGroup.shutdownGracefully();
}
private ChannelFuture startTcpClient(
TCPServerSourceConfig configBean,
EventLoopGroup workerGroup,
byte[] data,
boolean randomlySlice
) throws
InterruptedException {
ChannelFuture channelFuture;
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(workerGroup);
bootstrap.channel(NioSocketChannel.class);
bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
bootstrap.handler(new ChannelInitializer() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new TCPServerSourceClientHandler(randomlySlice, data));
}
});
// Start the client.
channelFuture = bootstrap.connect("localhost", Integer.parseInt(configBean.ports.get(0))).sync();
return channelFuture;
}
private static class TCPServerSourceClientHandler extends ChannelInboundHandlerAdapter {
private final boolean randomlySlice;
private final byte[] data;
private final BlockingQueue<String> responses = new LinkedBlockingDeque<>();
private TCPServerSourceClientHandler(boolean randomlySlice, byte[] data) {
this.randomlySlice = randomlySlice;
this.data = data;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ByteBuf buf = (ByteBuf) msg;
responses.add(buf.toString(com.google.common.base.Charsets.UTF_8));
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
if (randomlySlice) {
for (List<Byte> slice : NetTestUtils.getRandomByteSlices(data)) {
ctx.writeAndFlush(Unpooled.copiedBuffer(Bytes.toArray(slice)));
}
} else {
ctx.writeAndFlush(Unpooled.copiedBuffer(data));
}
}
private String getResponse() throws InterruptedException {
return responses.take();
}
}
private static List<Stage.ConfigIssue> initSourceAndGetIssues(TCPServerSourceConfig configBean) throws
StageException {
TCPServerSource source = new TCPServerSource(configBean);
PushSourceRunner runner = new PushSourceRunner.Builder(TCPServerDSource.class, source)
.addOutputLane("lane")
.setOnRecordError(OnRecordError.TO_ERROR)
.build();
return runner.runValidateConfigs();
}
protected static TCPServerSourceConfig createConfigBean(Charset charset) {
TCPServerSourceConfig config = new TCPServerSourceConfig();
config.batchSize = 10;
config.tlsEnabled = false;
config.numThreads = 1;
config.syslogCharset = charset.name();
config.tcpMode = TCPMode.SYSLOG;
config.syslogFramingMode= SyslogFramingMode.NON_TRANSPARENT_FRAMING;
config.nonTransparentFramingSeparatorCharStr = "\n";
config.maxMessageSize = 4096;
config.ports = Arrays.asList("9876");
return config;
}
}